/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the
 * "License"). Please refer to the License for details. You may not use this
 * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
 * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
 * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
 * for the full text of the License.
 */

#ifndef EXAMPLES_NORMALIZATION_LAYERNORM_GRAD_CUSTOM_H
#define EXAMPLES_NORMALIZATION_LAYERNORM_GRAD_CUSTOM_H
#include "kernel_operator.h"

namespace MyCustomKernel {
/**
 * Tiling bundle handed to KernelLayernormGrad::Init.
 *
 * Groups the tiling structure consumed by the LayerNormGrad call with the
 * scalar epsilon that is forwarded to the same call.
 */
struct VecTiling {
  // Tiling data for LayerNormGrad; its bLength/sLength/hLength fields are
  // read by the kernel to size every buffer.
  LayerNormGradTiling layernormGradTilingData;
  // Epsilon forwarded unchanged to LayerNormGrad; defaults to zero.
  float epsilon = 0.0f;
};

/**
 * Single-shot float kernel built around AscendC::LayerNormGrad.
 *
 * Usage: call Init() once with the global-memory addresses and tiling, then
 * Process(), which runs CopyIn -> Compute -> CopyOut exactly once over the
 * whole input (no tiling loop is performed here).
 *
 * @tparam isReuseSource forwarded as the second template argument of
 *         AscendC::LayerNormGrad.
 */
template <bool isReuseSource = false> class KernelLayernormGrad {
public:
  __aicore__ inline KernelLayernormGrad() {}

  /**
   * Binds the global tensors and reserves one local buffer per queue.
   *
   * Element counts are derived from the tiling fields:
   *   - x, dy, pdX, resForGamma : bLength * sLength * hLength
   *   - variance, mean          : bLength * sLength
   *   - gamma                   : hLength
   * NOTE(review): b/s/h presumably stand for batch/sequence/hidden; confirm
   * against the host-side tiling code.
   *
   * @param inputXGm        global address of input x
   * @param inputDyGm       global address of input dy
   * @param inputVarianceGm global address of the variance input
   * @param inputMeanGm     global address of the mean input
   * @param inputGammaGm    global address of the gamma input
   * @param outputPdXGm     global address receiving the pdX output
   * @param resForGammaGm   global address receiving the resForGamma output
   * @param tilingData      LayerNormGrad tiling plus epsilon
   */
  __aicore__ inline void Init(GM_ADDR inputXGm, GM_ADDR inputDyGm,
                              GM_ADDR inputVarianceGm, GM_ADDR inputMeanGm,
                              GM_ADDR inputGammaGm, GM_ADDR outputPdXGm,
                              GM_ADDR resForGammaGm, VecTiling tilingData) {
    // Cache the tiling and the scalar parameters that the later stages need.
    tiling_ = tilingData.layernormGradTilingData;
    epsilon = tilingData.epsilon;
    bLength = tiling_.bLength;
    sLength = tiling_.sLength;
    hLength = tiling_.hLength;

    // Derived element counts shared by the buffer setup and the copies.
    bsLength = bLength * sLength;
    bshLength = bLength * sLength * hLength;

    // Inputs.
    inputXGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputXGm),
                                 bshLength);
    inputDyGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputDyGm),
                                  bshLength);
    inputVarianceGlobal.SetGlobalBuffer(
        reinterpret_cast<__gm__ float *>(inputVarianceGm), bsLength);
    inputMeanGlobal.SetGlobalBuffer(
        reinterpret_cast<__gm__ float *>(inputMeanGm), bsLength);
    inputGammaGlobal.SetGlobalBuffer(
        reinterpret_cast<__gm__ float *>(inputGammaGm), hLength);

    // Outputs: both are sized to the full b*s*h element count.
    outputPdXGlobal.SetGlobalBuffer(
        reinterpret_cast<__gm__ float *>(outputPdXGm), bshLength);
    outputResForGammaGlobal.SetGlobalBuffer(
        reinterpret_cast<__gm__ float *>(resForGammaGm), bshLength);

    // Byte sizes of the three distinct buffer shapes.
    const auto bshBytes = sizeof(float) * bshLength;
    const auto bsBytes = sizeof(float) * bsLength;
    const auto hBytes = sizeof(float) * hLength;

    // One buffer per queue; the registration order is kept as-is.
    pipe.InitBuffer(inQueueX, 1, bshBytes);
    pipe.InitBuffer(inQueueDy, 1, bshBytes);
    pipe.InitBuffer(inQueueVariance, 1, bsBytes);
    pipe.InitBuffer(inQueueMean, 1, bsBytes);
    pipe.InitBuffer(inQueueGamma, 1, hBytes);
    pipe.InitBuffer(outQueuePdX, 1, bshBytes);
    pipe.InitBuffer(outQueueResForGamma, 1, bshBytes);
  }

  /** Runs the three stages once: load inputs, compute, store outputs. */
  __aicore__ inline void Process() {
    CopyIn();
    Compute();
    CopyOut();
  }

private:
  /** Copies every input from global memory into its queue's local tensor. */
  __aicore__ inline void CopyIn() {
    AscendC::LocalTensor<float> xIn = inQueueX.AllocTensor<float>();
    AscendC::LocalTensor<float> dyIn = inQueueDy.AllocTensor<float>();
    AscendC::LocalTensor<float> varianceIn =
        inQueueVariance.AllocTensor<float>();
    AscendC::LocalTensor<float> meanIn = inQueueMean.AllocTensor<float>();
    AscendC::LocalTensor<float> gammaIn = inQueueGamma.AllocTensor<float>();

    AscendC::DataCopy(xIn, inputXGlobal, bshLength);
    AscendC::DataCopy(dyIn, inputDyGlobal, bshLength);
    AscendC::DataCopy(varianceIn, inputVarianceGlobal, bsLength);
    AscendC::DataCopy(meanIn, inputMeanGlobal, bsLength);
    AscendC::DataCopy(gammaIn, inputGammaGlobal, hLength);

    // Hand the filled tensors over to the compute stage.
    inQueueX.EnQue(xIn);
    inQueueDy.EnQue(dyIn);
    inQueueVariance.EnQue(varianceIn);
    inQueueMean.EnQue(meanIn);
    inQueueGamma.EnQue(gammaIn);
  }

  /**
   * Dequeues the inputs, invokes AscendC::LayerNormGrad, enqueues both
   * results and releases the input tensors.
   */
  __aicore__ inline void Compute() {
    AscendC::LocalTensor<float> xIn = inQueueX.DeQue<float>();
    AscendC::LocalTensor<float> dyIn = inQueueDy.DeQue<float>();
    AscendC::LocalTensor<float> varianceIn = inQueueVariance.DeQue<float>();
    AscendC::LocalTensor<float> meanIn = inQueueMean.DeQue<float>();
    AscendC::LocalTensor<float> gammaIn = inQueueGamma.DeQue<float>();

    AscendC::LocalTensor<float> pdXOut = outQueuePdX.AllocTensor<float>();
    AscendC::LocalTensor<float> resForGammaOut =
        outQueueResForGamma.AllocTensor<float>();

    // Argument order: outputs first, then dy before x, then the statistics.
    AscendC::LayerNormGrad<float, isReuseSource>(
        pdXOut, resForGammaOut, dyIn, xIn, varianceIn, meanIn, gammaIn,
        epsilon, tiling_);

    outQueuePdX.EnQue(pdXOut);
    outQueueResForGamma.EnQue(resForGammaOut);

    inQueueX.FreeTensor(xIn);
    inQueueDy.FreeTensor(dyIn);
    inQueueVariance.FreeTensor(varianceIn);
    inQueueMean.FreeTensor(meanIn);
    inQueueGamma.FreeTensor(gammaIn);
  }

  /** Writes both result tensors back to global memory and frees them. */
  __aicore__ inline void CopyOut() {
    AscendC::LocalTensor<float> pdXOut = outQueuePdX.DeQue<float>();
    AscendC::LocalTensor<float> resForGammaOut =
        outQueueResForGamma.DeQue<float>();

    AscendC::DataCopy(outputPdXGlobal, pdXOut, bshLength);
    AscendC::DataCopy(outputResForGammaGlobal, resForGammaOut, bshLength);

    outQueuePdX.FreeTensor(pdXOut);
    outQueueResForGamma.FreeTensor(resForGammaOut);
  }

  // Global-memory views of the five inputs and two outputs.
  AscendC::GlobalTensor<float> inputXGlobal;
  AscendC::GlobalTensor<float> inputDyGlobal;
  AscendC::GlobalTensor<float> inputVarianceGlobal;
  AscendC::GlobalTensor<float> inputMeanGlobal;
  AscendC::GlobalTensor<float> inputGammaGlobal;
  AscendC::GlobalTensor<float> outputPdXGlobal;
  AscendC::GlobalTensor<float> outputResForGammaGlobal;

  // Pipe that owns the queue buffers, followed by one queue per tensor.
  AscendC::TPipe pipe;
  AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueX;
  AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueDy;
  AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueVariance;
  AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueMean;
  AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueGamma;

  AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueuePdX;
  AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueResForGamma;

  // Dimensions copied from the tiling in Init().
  uint32_t bLength;
  uint32_t sLength;
  uint32_t hLength;
  // Epsilon and tiling forwarded to LayerNormGrad.
  float epsilon;
  LayerNormGradTiling tiling_;

  // Derived element counts: b*s*h and b*s.
  uint32_t bshLength;
  uint32_t bsLength;
};
} // namespace MyCustomKernel
#endif // EXAMPLES_NORMALIZATION_LAYERNORM_GRAD_CUSTOM_H